/*
 * @Author: gaoxinglong
 * @Date: 2022-11-16 14:24:20
 * @LastEditTime: 2022-12-01 20:35:04
 * @LastEditors: gaoxinglong
 */
#include <sys/time.h>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <memory>
#include <regex>
#include <string>
#include <thread>
#include <vector>

#include <torch/script.h> // One-stop header.

// #include "feature_pipeline.h"
#include "streaming_asr_on_device.h"
#include "wav.h"

// Splits `input` at every match of the regular expression `reg_sep` and
// returns the text between matches, in order. Following
// std::regex_token_iterator semantics (submatch -1), a separator at the very
// start of `input` yields an empty first piece, while a separator at the very
// end does not yield an empty last piece.
std::vector<std::string> split_with_separators(const std::string &input,
                                               const std::string &reg_sep) {
  const std::regex separator(reg_sep);
  const std::sregex_token_iterator end_of_tokens;
  std::vector<std::string> pieces;
  for (std::sregex_token_iterator token(input.begin(), input.end(), separator,
                                        -1);
       token != end_of_tokens; ++token) {
    pieces.push_back(token->str());
  }
  return pieces;
}

// Returns the current wall-clock time in milliseconds since the Unix epoch.
// Just call it directly; store the result in an int64_t (long long also works).
int64_t getCurrentTime() {
  struct timeval tv;
  gettimeofday(&tv, nullptr); // declared in <sys/time.h>
  // Widen before multiplying: on platforms where time_t is 32 bits,
  // tv_sec * 1000 would overflow a signed 32-bit integer (undefined behavior).
  return static_cast<int64_t>(tv.tv_sec) * 1000 +
         static_cast<int64_t>(tv.tv_usec) / 1000;
}

// Job description handed to one stream_rec_worker thread.
struct stream_engine_info {
  // Path of the ASR config used to build StreamingAsrConfig; not owned here.
  const char *asr_config = nullptr;
  // Index of the worker; appears in its log lines.
  int32_t asr_engine_id = 0;
  // Task-list lines ("<key> <wav_path>") assigned to this worker.
  std::vector<std::string> task_list;
  // File that receives one "<key>\t<hypothesis>" line per decoded task.
  std::string result_filename;
};

// Decodes every task assigned to one worker with its own ASR engine.
//
// Each task is expected to be a whitespace-separated "<key> <wav_path>" line.
// Malformed lines are reported on stderr and skipped. For every decoded
// utterance one "<key>\t<hypothesis>" line is written to
// info->result_filename.
//
// Returns 0 on success, or -1 if the result file cannot be opened. In that
// case nothing is decoded, because the results would be lost anyway.
int32_t stream_rec_worker(stream_engine_info *info) {
  std::ofstream os(info->result_filename, std::ios::out);
  if (!os.is_open()) {
    fprintf(stderr, "engine:%d, failed to open result file %s\n",
            info->asr_engine_id, info->result_filename.c_str());
    return -1;
  }

  StreamingAsrConfig config(info->asr_config);
  StreamingAsrOnDevice asr_engine(config);
  asr_engine.init();
  // %zu: task_list.size() is a size_t; %d would be undefined behavior.
  fprintf(stderr, "stream_rec_worker %d engine, task %zu\n",
          info->asr_engine_id, info->task_list.size());

  constexpr int64_t kSendSamples = 4000; // samples fed per ProcessBlock call
  int32_t num_success = 0;
  for (const std::string &task : info->task_list) {
    const std::vector<std::string> key_and_value =
        split_with_separators(task, "\\s+");
    for (const std::string &field : key_and_value) {
      std::cout << field << ", size : " << key_and_value.size() << std::endl;
    }
    if (key_and_value.size() != 2) {
      fprintf(stderr, "engine:%d, skip malformed task line: %s\n",
              info->asr_engine_id, task.c_str());
      continue;
    }

    WavReader wav_reader(key_and_value[1].c_str());
    const float *data = wav_reader.data();
    const int64_t num_sample = wav_reader.num_sample();
    std::cout << "getting " << num_sample << " samples" << std::endl;

    std::string temp_result;
    for (int64_t offset = 0; offset < num_sample; offset += kSendSamples) {
      const int64_t chunk = std::min(kSendSamples, num_sample - offset);
      // "Reaches the end" rather than "is shorter than a full chunk". With the
      // latter, when num_sample is an exact multiple of kSendSamples,
      // ProcessBlock never sees is_last=true and reset() is skipped before the
      // next utterance.
      const bool is_last = offset + chunk >= num_sample;
      if (is_last) {
        std::cout << "This is the last sample" << std::endl;
      }
      temp_result = asr_engine.ProcessBlock(data + offset, chunk, is_last);
      std::cout << temp_result << std::endl;
      if (is_last) {
        asr_engine.reset();
      }
    }
    fprintf(stdout, "engine:%d, %s:%s\n", info->asr_engine_id,
            key_and_value[0].c_str(), temp_result.c_str());
    // std::endl flushes, so partial results survive an aborted run.
    os << key_and_value[0] << "\t" << temp_result << std::endl;
    num_success++;
  }
  os.close();

  fprintf(stderr, "stream_rec_worker %d engine, decoded %d/%zu tasks\n",
          info->asr_engine_id, num_success, info->task_list.size());
  return 0;
}

void streaming_decode_multi_thread(const char *decode_config,
                                   const char *wav_scp, int32_t num_threads) {
  // preparation for multi-thread data processing
  std::vector<std::string> task_list;
  std::string line;
  std::ifstream infile(wav_scp);
  fprintf(stdout, "start to read task list\n");
  while (std::getline(infile, line, '\n')) {
    if (!line.empty()) {
      task_list.emplace_back(line);
    }
  }
  infile.close();
  fprintf(stderr, "read rec tasks from %s, number:%d\n", wav_scp,
          task_list.size());
  // start to process multi-thread task_list
  int32_t tasks_every_workers = task_list.size() / num_threads + 1;
  std::vector<stream_engine_info> workers_infos(num_threads);
  for (int32_t i = 0; i < num_threads; i++) {
    workers_infos[i].asr_config = decode_config;
    workers_infos[i].asr_engine_id = i;
    int64_t start_index = i * tasks_every_workers;
    int64_t end_index =
        std::min(static_cast<int64_t>((i + 1) * tasks_every_workers),
                 static_cast<int64_t>(task_list.size()));
    workers_infos[i].task_list.resize(end_index - start_index);
    fprintf(stdout, "copy %d tasks to thread %d\n", end_index - start_index, i);
    std::copy(task_list.begin() + start_index, task_list.begin() + end_index,
              workers_infos[i].task_list.begin());
    workers_infos[i].result_filename = "exp/rec_" + std::to_string(i) + ".txt";
  }
  std::vector<std::thread> workers;
  for (int32_t i = 0; i < num_threads; i++) {
    workers.emplace_back(stream_rec_worker, &workers_infos[i]);
  }
  for (size_t i = 0; i < num_threads; i++) {
    workers[i].join();
  }
}

// Streaming smoke test for a single wav file. It decodes the file twice with
// the same engine, calling reset() in between, and prints every hypothesis
// returned by ProcessBlock. The second pass presumably checks that the engine
// is reusable after reset().
void streaming_decode_test(const char *decode_config,
                           const char *single_wav_fn) {
  StreamingAsrConfig asr_config(decode_config);
  StreamingAsrOnDevice asr(asr_config);
  asr.init();

  WavReader wav_reader(single_wav_fn);
  const float *data = wav_reader.data();
  const int num_sample = wav_reader.num_sample();
  std::cout << "getting " << num_sample << " samples" << std::endl;

  constexpr int kChunkSamples = 4000; // samples fed per ProcessBlock call

  // Feeds the whole file in kChunkSamples-sized chunks and flags the final
  // chunk as last. `<=` (not `<`) makes sure that flag is also set when
  // num_sample is an exact multiple of kChunkSamples.
  auto decode_pass = [&](bool announce_last) {
    for (int offset = 0; offset < num_sample; offset += kChunkSamples) {
      const int remaining = num_sample - offset;
      const bool is_last = remaining <= kChunkSamples;
      if (is_last && announce_last) {
        std::cout << "This is the last sample" << std::endl;
      }
      const std::string hpy = asr.ProcessBlock(
          data + offset, std::min(kChunkSamples, remaining), is_last);
      std::cout << hpy << std::endl;
    }
  };

  decode_pass(true);
  asr.reset();
  decode_pass(false);
}

// Ad-hoc latency check for the TorchScript model. It loads a raw float32
// feature dump of 702 frames x 80 dims and calls the scripted "on_recieve"
// method once. It then prints the output shape, up to the first 21 entries
// along dim 0, and the call latency in milliseconds.
void input_once_test() {
  constexpr int64_t kNumFrames = 702;
  constexpr int64_t kFeatDim = 80;

  //[1]: get feature from byte
  const std::string filename("/home/changhengyi/cpp_dev/assets/xs.byte");
  // Heap buffer: 702*80 floats (~220 KB) is too large to place on the stack.
  std::vector<float> feat(static_cast<size_t>(kNumFrames * kFeatDim));
  FILE *fp = fopen(filename.c_str(), "rb");
  if (fp == nullptr) {
    fprintf(stderr, "failed to open feature file %s\n", filename.c_str());
    return;
  }
  const size_t num_read = fread(feat.data(), sizeof(float), feat.size(), fp);
  fclose(fp);
  if (num_read != feat.size()) {
    fprintf(stderr, "short read from %s: got %zu of %zu floats\n",
            filename.c_str(), num_read, feat.size());
    return;
  }

  // View the row-major buffer as a (702, 80) float tensor, then clone() so the
  // tensor owns its storage. This replaces 56k element-wise tensor writes.
  torch::Tensor one_torch_feat =
      torch::from_blob(feat.data(), {kNumFrames, kFeatDim}).clone();

  //[2]: inference
  torch::jit::script::Module asr =
      torch::jit::load("/home/changhengyi/cpp_dev/assets/streaming_asr.script");
  asr.run_method("reset");
  const int64_t st = getCurrentTime();
  // The meaning of the positional arguments (700, false, false) is defined by
  // the scripted model. NOTE(review): 700 is presumably a frame count; confirm.
  auto output = asr.run_method("on_recieve", one_torch_feat, 700, false, false)
                    .toTensor();
  const int64_t et = getCurrentTime();
  std::cout << output.sizes() << std::endl;
  // Guard against outputs with fewer than 21 entries along dim 0.
  const int64_t num_print =
      output.dim() > 0 ? std::min<int64_t>(21, output.size(0)) : 0;
  for (int64_t i = 0; i < num_print; i++) {
    std::cout << output[i] << " ";
  }
  std::cout << std::endl;
  std::cout << "time cost:" << et - st << std::endl;
}

// Entry point.
//   argv[1]: ASR config file.
//   argv[2]: a single wav file (mode 0) or a task list of "<key> <wav_path>"
//            lines (mode > 0).
//   argv[3]: optional mode. 0 (default) runs the single-file streaming test;
//            N > 0 decodes the task list with N worker threads.
int main(int argc, char *argv[]) {
  if (argc < 3) {
    fprintf(stderr,
            "Usage:\n%s <android_asr.config> <single_wav_fn|wav_scp> [mode]\n"
            "  mode 0 (default): streaming decode of <single_wav_fn>\n"
            "  mode N > 0: decode every line of <wav_scp> with N threads\n",
            argv[0]);
    return EXIT_FAILURE;
  }
  const std::string config_path = argv[1];
  const std::string wav_path = argv[2];
  const int32_t mode = argc > 3 ? std::atoi(argv[3]) : 0;
  if (mode < 0) {
    // Only reachable when argv[3] was given, since the default is 0.
    fprintf(stderr, "invalid mode %s: must be >= 0\n", argv[3]);
    return EXIT_FAILURE;
  }

  if (mode == 0) {
    streaming_decode_test(config_path.c_str(), wav_path.c_str());
  } else {
    streaming_decode_multi_thread(config_path.c_str(), wav_path.c_str(), mode);
  }

  return 0;
}